e3565116c1d02cef24764b40a8106c715fa69dce,src/main/java/ml/shifu/shifu/core/dtrain/lr/LogisticRegressionWorker.java,LogisticRegressionWorker,load,#GuaguaWritableAdapter#GuaguaWritableAdapter#WorkerContext#,300
Before Change
// if fixInitialInput = false, we only compare random value with baggingSampleRate to avoid parsing data.
// if fixInitialInput = true, we should use hashcode after parsing.
double baggingSampleRate = this.modelConfig.getBaggingSampleRate();
if(!this.modelConfig.isFixInitialInput() && Double.compare(Math.random(), baggingSampleRate) >= 0) {
// for negative tags, do sampleNegOnly logic
if(modelConfig.getTrain().getSampleNegOnly()) {
if(modelConfig.isRegression() && Double.compare(outputData[0] + 0d, 0d) == 0) {
return;
}
} else {
return;// normal sampling
}
}
// if fixInitialInput = true, we should use hashcode to sample.
long longBaggingSampleRate = Double.valueOf(baggingSampleRate * 100).longValue();
if(modelConfig.isFixInitialInput() && hashcode % 100 >= longBaggingSampleRate) {
// for negative tags, do sampleNegOnly logic
if(modelConfig.getTrain().getSampleNegOnly()) {
if(modelConfig.isRegression() && Double.compare(outputData[0] + 0d, 0d) == 0) {
return;
}
} else {
return;// normal sampling
}
After Change
}
// if only sample negative, no matter bagging or replacement, do sampling here.
if(modelConfig.getTrain().getSampleNegOnly() // sample negative enabled
&& (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain()
.isOneVsAll())) // regression or onevsall
&& Double.compare(outputData[0] + 0.01d, 0d) == 0 // negative record
&& (!this.modelConfig.isFixInitialInput() && Double.compare(Math.random(),
this.modelConfig.getBaggingSampleRate()) >= 0)) {
return;
}
if(modelConfig.getTrain().getSampleNegOnly()// sample negative enabled
&& (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain()
.isOneVsAll()))// regression or onevsall
&& (Double.compare(outputData[0] + 0.01d, 0d) == 0 // negative record
&& this.modelConfig.isFixInitialInput() && hashcode % 100 >= Double.valueOf(
this.modelConfig.getBaggingSampleRate() * 100).longValue())) {
return;
}
Data data = new Data(inputData, outputData, significance);
// up sampling logic, just add more weights while bagging sampling rate is still not changed
if(modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(outputData[0], 1d) == 0) {
// Double.compare(outputData[0], 1d) == 0 means a positive tag; adding 1 to the sample keeps the multiplier from being 0
data.setSignificance(data.significance * (this.upSampleRng.sample() + 1));
}
boolean isValidation = false;
if(context.getAttachment() != null && context.getAttachment() instanceof Boolean) {
isValidation = (Boolean) context.getAttachment();
}
boolean isInTraining = addDataPairToDataSet(hashcode, data, isValidation);
// do bagging sampling only for training data,
if(isInTraining) {
float subsampleWeights = sampleWeights(outputData[0]);
if(isPositive(outputData[0])) {
this.positiveSelectedTrainCount += subsampleWeights * 1L;
} else {
this.negativeSelectedTrainCount += subsampleWeights * 1L;
}
// set weights to significance, if 0, significance will be 0, that is bagging sampling
data.setSignificance(data.significance * subsampleWeights);
} else {
// for validation data, the bagging sampling logic would suggest sampling the validation data set as well; however,
// the validation data set is only used to compute validation error, so skipping real sampling here is OK.